{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "550e0040-da75-4a4f-8d2d-664029540d51",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import patches\n",
    "import torch\n",
    "import torchvision\n",
    "from pol.datasets.objdetect import COCODataset, collate_padding_fn\n",
    "from torch.utils.data import DataLoader, Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a21657c-49ed-47df-8ead-9a50688f5392",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MiniDETR(torch.nn.Module):\n",
    "    '''\n",
    "    Image encoder that maps a batch of images to one latent vector each.\n",
    "\n",
    "    A pretrained torchvision ResNet-50 (with its fc head deleted) yields a\n",
    "    2048-channel feature map, which a 1x1 convolution projects to\n",
    "    hidden_dim channels. The map is then reduced to a single vector per\n",
    "    image, either by a torch.nn.Transformer driven by learned queries\n",
    "    (use_transformer=True) or by global average pooling.\n",
    "\n",
    "    Args:\n",
    "        hidden_dim: size D of the returned latent code.\n",
    "        use_transformer: aggregate with a transformer if True, otherwise\n",
    "            with adaptive average pooling.\n",
    "        nheads: number of attention heads (transformer only).\n",
    "        num_encoder_layers: encoder depth (transformer only).\n",
    "        num_decoder_layers: decoder depth (transformer only).\n",
    "        num_queries: number of learned decoder queries (transformer only).\n",
    "        max_pos_size: rows/columns in the learned positional tables, i.e.\n",
    "            the largest feature-map height/width that forward() accepts\n",
    "            (transformer only).\n",
    "    '''\n",
    "    def __init__(self, hidden_dim=256,\n",
    "                 use_transformer=True, nheads=8,\n",
    "                 num_encoder_layers=6, num_decoder_layers=6,\n",
    "                 num_queries=100, max_pos_size=50):\n",
    "        super().__init__()\n",
    "\n",
    "        self.use_transformer = use_transformer\n",
    "        self.backbone = torchvision.models.resnet50(pretrained=True)\n",
    "        # forward() only calls the convolutional stages, so drop the classifier.\n",
    "        del self.backbone.fc\n",
    "\n",
    "        # 1x1 convolution: 2048 backbone channels -> hidden_dim\n",
    "        self.conv = torch.nn.Conv2d(2048, hidden_dim, 1)\n",
    "\n",
    "        if self.use_transformer:\n",
    "            self.transformer = torch.nn.Transformer(\n",
    "                hidden_dim, nheads, num_encoder_layers, num_decoder_layers\n",
    "            )\n",
    "            # Learned decoder queries, one row per query.\n",
    "            self.query_pos = torch.nn.Parameter(\n",
    "                torch.rand(num_queries, hidden_dim))\n",
    "\n",
    "            # Learned 2D positional tables; each contributes half of the channels.\n",
    "            self.row_embed = torch.nn.Parameter(\n",
    "                torch.rand(max_pos_size, hidden_dim // 2))\n",
    "            self.col_embed = torch.nn.Parameter(\n",
    "                torch.rand(max_pos_size, hidden_dim // 2))\n",
    "            # Collapses the query axis into a single vector per image.\n",
    "            self.final_fc = torch.nn.Linear(num_queries, 1)\n",
    "        else:\n",
    "            self.pooling = torch.nn.AdaptiveAvgPool2d((1, 1))\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        '''\n",
    "        Args:\n",
    "            inputs: Bx3xHxW, batched images\n",
    "\n",
    "        Returns:\n",
    "            BxD, embedded latent codes\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: in transformer mode, if the backbone feature map is\n",
    "                larger than the positional tables (see max_pos_size).\n",
    "        '''\n",
    "        x = self.backbone.conv1(inputs)\n",
    "        x = self.backbone.bn1(x)\n",
    "        x = self.backbone.relu(x)\n",
    "        x = self.backbone.maxpool(x)\n",
    "\n",
    "        x = self.backbone.layer1(x)\n",
    "        x = self.backbone.layer2(x)\n",
    "        x = self.backbone.layer3(x)\n",
    "        x = self.backbone.layer4(x)\n",
    "\n",
    "        # x is Bx2048xH'xW'\n",
    "        h = self.conv(x) # BxDxH'xW'\n",
    "\n",
    "        if self.use_transformer:\n",
    "            B, _, H, W = h.shape\n",
    "            max_h, max_w = self.row_embed.shape[0], self.col_embed.shape[0]\n",
    "            if H > max_h or W > max_w:\n",
    "                # Slicing the tables would silently truncate and torch.cat would\n",
    "                # then fail with an opaque size mismatch; fail early instead.\n",
    "                raise RuntimeError(\n",
    "                    f'feature map {H}x{W} exceeds the positional embedding '\n",
    "                    f'table {max_h}x{max_w}; increase max_pos_size or use '\n",
    "                    f'smaller images')\n",
    "            pos = torch.cat([\n",
    "                self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n",
    "                self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),\n",
    "            ], dim=-1).flatten(0, 1).unsqueeze(1)\n",
    "            # pos: H'W'x1xD\n",
    "            # Features are scaled by 0.1 before the positional code is added.\n",
    "            h = 0.1 * h.flatten(2).permute(2, 0, 1) # H'W'xBxD\n",
    "            h = self.transformer(pos + h,\n",
    "                                 self.query_pos.unsqueeze(1).repeat(\n",
    "                                     1, B,\n",
    "                                     1)) # QxBxD, Q = number of queries\n",
    "            h = self.final_fc(h.permute(1, 2, 0)) # BxDx1\n",
    "            return h.squeeze(-1)\n",
    "        else:\n",
    "            h = self.pooling(h) # BxDx1x1\n",
    "            return h.squeeze(-1).squeeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6b180bd0-a507-4687-89f4-c243f6fe0d11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading split 'validation' to '/home/lingxiao/fiftyone/coco-2017/validation' if necessary\n",
      "Found annotations at '/home/lingxiao/fiftyone/coco-2017/raw/instances_val2017.json'\n",
      "Sufficient images already downloaded\n",
      "Existing download of split 'validation' is sufficient\n",
      "Loading 'coco-2017' split 'validation'\n",
      " 100% |█████████████████| 349/349 [873.6ms elapsed, 0s remaining, 399.5 samples/s]      \n",
      "Dataset 'coco-2017-validation' created\n",
      "torch.Size([4, 256])\n"
     ]
    }
   ],
   "source": [
    "img_encoder = MiniDETR(use_transformer=False)\n",
    "dataset = COCODataset(split='validation', max_num_detection=10)\n",
    "dataloader = DataLoader(dataset, batch_size=4, shuffle=True,\n",
    "                        collate_fn=collate_padding_fn, drop_last=True)\n",
    "\n",
    "# Smoke test: run one batch through the encoder and print the latent shape.\n",
    "# Gradients are not needed for a shape check, so skip building the autograd graph.\n",
    "with torch.no_grad():\n",
    "    for data in dataloader:\n",
    "        batch_img = data['img']\n",
    "        output = img_encoder(batch_img)\n",
    "        print(output.shape)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ab9ba1-caa2-44dd-99b5-4cc99f067fb7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
